#!/usr/bin/python

import code,pickle,sys,os,re,traceback,cPickle,glob
import matplotlib
if "DISPLAY" not in os.environ: matplotlib.use("AGG")
else: matplotlib.use("GTK")
import random as pyrandom
from optparse import OptionParser
from pylab import *
from scipy import stats
import ocrolib
import heapq
from ocrolib import dbtables,quant,utils,gatedmodel,lru,docproc,Record,mlp

def log_progress(fmt,*args):
    """Write a transient, %-formatted status message to stderr.

    `fmt` is interpolated with `args` using the % operator.  The message is
    followed by the escape sequence ESC[K and a carriage return, so that the
    next message written to the terminal starts at the beginning of the
    same line.
    """
    status = fmt%args
    sys.stderr.write(status+"\033[K\r")

# Command line interface: two positional arguments (the character database
# and the directory holding the per-gate models) plus the tuning options
# below.  The parsed values end up in `options`, the positionals in `args`.
parser = OptionParser(usage="""
%prog [options] input.db dir

Perform training of gated models.
""")
parser.add_option("-C","--cachesize",help="number of characters to be cached (1M approx 8Gbyte)",type=int,default=1000000)
parser.add_option("-c","--cutoff",help="cutoff used for gating",type=float,default=0.75)
parser.add_option("-m","--model",help="starter model",default=None)
parser.add_option("-D","--display",help="display",action="store_true")
parser.add_option("-N","--limit",help="limit training",type=int,default=100000000)
parser.add_option("-S","--nsample",help="number of samples for estimation",type=int,default=100000)
parser.add_option("-r","--rounds",help="mlp rounds",type=int,default=24)
parser.add_option("-n","--ntrain",help="ntrain",type=int,default=150000)
parser.add_option("-t","--table",help="table",default="chars")
# the threshold scales the percentile used for the gate relative to the
# percentile used for selecting training samples
parser.add_option("-T","--threshold",help="gate threshold as a fraction of the training cutoff percentile",type=float,default=0.75)
parser.add_option("-o","--output",help="output file",default="gated.cmodel")
(options,args) = parser.parse_args()

# Both positional arguments are required; a wrong argument count is a usage
# error, so show the help text and exit with a non-zero status.
if len(args)!=2:
    parser.print_help()
    sys.exit(1)

# dbfile: character database to read; cdir: directory for the trained models
dbfile,cdir = args

# NOTE(review): `table`, `ntest` and `nrounds` are not referenced in the
# visible part of this script (the queries use the table name "chars"
# directly) -- confirm before relying on them.
table = options.table
ntrain = options.ntrain
ntest = 100000
nrounds = 16

print "opening db"
db = utils.chardb(dbfile,"chars")

if options.model:
    print "loading clasifier"
    model = ocrolib.load_component(args[1])
    if hasattr(model,'addGated') and hasattr(model,'coutputs'):
        gated = model
    else:
        gated = gatedmodel.GatedModel()
        gated.addGated(gatedmodel.AlwaysGate(),model)
    del model
else:
    gated = gatedmodel.GatedModel()
    preload = glob.glob(cdir+"/*.*model")
    print "preloading"," ".join(preload[:4]),("..." if len(preload)>4 else "")
    gatedmodel.load_gatedmodel(gated,preload,cutoff=options.cutoff)

print "creating output dir"
if not os.path.exists(cdir): os.mkdir(cdir)
assert os.path.isdir(cdir)

# get the list of all samples

print "loading"
classes = {}
for r in db.execute("select id,cls from chars limit %d"%options.limit):
    classes[r.id] = r.cls

ids = sorted(classes.keys())

# simple function to get database samples
# the LRU cache avoids unnecessary disk accesses

@lru.lru_cache(maxsize=options.cachesize)
def getrow(id):
    """Fetch the sample with the given id from the `chars` table.

    Returns a Record with the sample's class (`cls`), its image converted
    with utils.blob2image and divided by 255.0 (`image`), and its geometry
    passed through docproc.rel_geo_normalize (`rel`).  Results are kept in
    an LRU cache holding up to options.cachesize entries.

    Raises IndexError if the query yields no row for `id`.
    """
    matches = list(db.execute("select * from chars where id=?",(id,)))
    found = matches[0]
    pixels = utils.blob2image(found.image)/255.0
    geometry = docproc.rel_geo_normalize(found.rel)
    return Record(cls=found.cls,image=pixels,rel=geometry)

# predict all the samples in the database

print "predicting"
predictions = {}
total = 0
for i in ids:
    if i%1000==0: log_progress("%8d / %8d",i,len(ids))
    row = getrow(i)
    pred = gated.cclassify(row.image,geometry=row.rel)
    predictions[i] = pred
    total += 1

print "\ngetrow hits",getrow.hits,"misses",getrow.misses

# given the predictions and actual classes, compute the list of errors

def errors():
    """Return the ids whose current prediction differs from the true class.

    Reads the module-level `ids`, `predictions` and `classes`; the result is
    a new list that keeps the order of `ids`.
    """
    return [sample for sample in ids if predictions[sample]!=classes[sample]]

bad = errors()
print "\nerrors",len(bad)

while 1:
    # pick a random sample from the misclassified samples

    center_id = pyrandom.sample(bad,1)[0]
    print "=== training",center_id,"==="
    row = getrow(center_id)
    center = gated.extract(row.image)

    # determine a cutoff in order to get approximately the desired number of
    # training samples

    print "sampling to determine cutoff"
    samples = pyrandom.sample(ids,min(len(ids),options.nsample))
    dists = [quant.dist(center,gated.extract(getrow(i).image)) for i in samples]
    if options.display:
        clf(); hist(dists); ginput(1,timeout=1)
    frac = ntrain*1.0/len(ids)
    cutoff = stats.scoreatpercentile(dists,per=100.0*frac)
    cutoffs = [stats.scoreatpercentile(dists,per=100.0*frac*f) for f in linspace(0.0,1.0,100)]
    cutoffs = array(cutoffs,'f')
    threshold = stats.scoreatpercentile(dists,per=frac*100.0*options.threshold)
    print "cutoff",cutoff,threshold

    # compute the gate

    gate = gatedmodel.DistanceGate(center,threshold)

    # now find all the samples that fall within this gate

    print "getting training sample"
    model = mlp.AutoMlpModel(max_rounds=options.rounds)
    total = 0
    testsamples = []
    for i in ids:
        row = getrow(i)
        v = gated.extract(row.image)
        d = quant.dist(center,v)
        if d<cutoff:
            model.cadd(row.image,row.cls,geometry=row.rel)
            total += 1
            if total%1000==0: log_progress("%8d",total)
        if gate.check(v):
            testsamples.append(i)
    print "\ngot",total,"samples"

    for progress in model.updateModel1():
        log_progress("%s",progress)

    # write out the model

    fname = cdir+"/%08d.cmodel"%center_id
    print "saving",fname
    ocrolib.save_component(fname,model)
    record = utils.Record(center_id=center_id,
                          center=center,
                          cutoff=cutoff,
                          cutoffs=cutoffs,
                          ntrain=total)

    # write out the corresponding info

    fname = cdir+"/%08d.info"%center_id
    with open(fname,"w") as stream:
        cPickle.dump(record,stream,2)

    # update the predictions of samples that are affected by this gated classifier

    print "predicting"
    gated.addGated(gate,model)

    for i in testsamples:
        row = getrow(i)
        predictions[i] = gated.cclassify(row.image,geometry=row.rel)

    # compute the new error rate

    old = len(bad)
    bad = errors()
    print "nerrors",len(bad),"old",old
